[ReplayBuffer] add ReplayBuffer with various StorageBackend #1490
YanhuiDua wants to merge 2 commits into InternLM:rl_design
Conversation
…aleness, or Database (to be implemented in the future)
Pull request overview
This PR introduces a new ReplayBuffer abstraction in xtuner/v1/rl/base with pluggable storage backends (e.g., FIFO and staleness-priority), plus initial unit tests for basic FIFO and staleness ordering behavior.
Changes:
- Added `ReplayBuffer`, a `StorageBackend` interface, and multiple backend implementations (`FIFOStorageBackend`, `StalenessStorageBackend`, plus stub/pseudocode backends).
- Implemented `StorageIndices` to partition storage by `(task_name, group_status, tags)`.
- Added async unit tests covering FIFO behavior, staleness priority order, and multi-task isolation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| `xtuner/v1/rl/base/replay_buffer.py` | Adds the replay buffer API and backend implementations (FIFO + staleness), with placeholder backends for future extensions. |
| `tests/ray/test_replay_buffer.py` | Adds async unit tests validating basic replay buffer behavior for FIFO and staleness backends. |
```python
@dataclass
class StorageIndices:
    # 为不同存储后段提供统一的索引接口
```
`StorageIndices` doc comment has a typo: “存储后段” should be “存储后端” (“storage backend”).

```diff
- # 为不同存储后段提供统一的索引接口
+ # 为不同存储后端提供统一的索引接口
```
```python
class ReplayBuffer:
    def __init__(self, storage_backend: StorageBackend = None):
```
`ReplayBuffer.__init__` annotates `storage_backend` as `StorageBackend` but defaults it to `None`, which will fail type checking under mypy’s strict optional rules. Please change the annotation to `StorageBackend | None` (or `Optional[StorageBackend]`).

```diff
- def __init__(self, storage_backend: StorageBackend = None):
+ def __init__(self, storage_backend: StorageBackend | None = None):
```
```python
indices = self._hash_storage_indices(storage_indices)
group_seq_staleness = max([item.seq_staleness for item in items])
```
`StalenessStorageBackend.put` will crash if `items` is empty (`max()` on an empty list). Also, if any `seq_staleness` falls outside `[min_staleness, max_staleness]` (defaults are both 0), `self._storage[indices][group_seq_staleness]` will raise `KeyError`. Consider explicitly handling empty input (no-op or clear error) and validating/clamping `seq_staleness` to the configured bucket range (or dynamically creating buckets).

```python
# If there are no items, treat this as a no-op to avoid max() on an empty list.
if not items:
    return
indices = self._hash_storage_indices(storage_indices)
group_seq_staleness = max(item.seq_staleness for item in items)
# Clamp staleness into the configured bucket range to avoid KeyError.
group_seq_staleness = max(self.min_staleness,
                          min(self.max_staleness, group_seq_staleness))
```
```python
def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0):
    self.limit = limit
    self.max_staleness = max_staleness
    self.min_staleness = min_staleness
    self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)})
    self._bucket_counts = defaultdict(int)
```
`StalenessStorageBackend.__init__` accepts a `limit` parameter but never enforces it (items can grow unbounded). Either implement eviction behavior consistent with `FIFOStorageBackend(limit=...)` or remove the parameter to avoid a misleading API.
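One possible enforcement policy, sketched as an assumption rather than what the PR does: clamp staleness into the bucket range, track per-partition counts, and evict whole groups starting from the stalest bucket once the partition exceeds `limit`. `StalenessStoreSketch` is a hypothetical standalone class whose field names mirror the PR.

```python
from collections import defaultdict, deque

class StalenessStoreSketch:
    """Hypothetical staleness-bucketed store that actually enforces `limit`."""

    def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0):
        self.limit = limit  # 0 means unbounded
        self.min_staleness = min_staleness
        self.max_staleness = max_staleness
        self._storage = defaultdict(
            lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)}
        )
        self._bucket_counts = defaultdict(int)

    def put(self, key, staleness: int, items: list) -> None:
        if not items:  # no-op on empty input, avoiding max()/KeyError issues
            return
        staleness = max(self.min_staleness, min(self.max_staleness, staleness))
        self._storage[key][staleness].append(items)
        self._bucket_counts[key] += len(items)
        # Enforce `limit`: evict the oldest group from the stalest non-empty bucket.
        while self.limit > 0 and self._bucket_counts[key] > self.limit:
            for s in range(self.max_staleness, self.min_staleness - 1, -1):
                if self._storage[key][s]:
                    evicted = self._storage[key][s].popleft()
                    self._bucket_counts[key] -= len(evicted)
                    break
```

Evicting stalest-first is only one choice; evicting freshest-first or rejecting new writes would also satisfy the bound.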
```python
new_df = pd.DataFrame(new_rows)
self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False)

def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
```
`PandasStorageBackend.get` is defined as a synchronous method, but the `StorageBackend` interface defines `async def get(...)`. Even though the class currently raises in `__init__`, keeping signatures consistent will prevent accidental misuse later and avoids confusing API expectations (callers will `await` this).

```diff
- def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
+ async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
```
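If the pandas backend is implemented later, one way to honor the async interface without making the pandas code itself async is to run the blocking query in a thread. A standalone sketch (`SyncBackedStore` and its list of rows are hypothetical stand-ins for the DataFrame logic):

```python
import asyncio

class SyncBackedStore:
    """Hypothetical backend whose storage layer is synchronous/blocking."""

    def __init__(self, rows):
        self._rows = rows

    def _get_sync(self, count):
        # Stand-in for a blocking pandas filter/lookup.
        return self._rows[:count]

    async def get(self, count):
        # Keep the async signature; offload the blocking call to a thread.
        return await asyncio.to_thread(self._get_sync, count)

print(asyncio.run(SyncBackedStore([1, 2, 3]).get(2)))  # → [1, 2]
```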
```python
def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
    base = (indices.task_name, indices.group_status)

    if indices.tags:
        sorted_tags = tuple(sorted(indices.tags.items()))
        return base + sorted_tags
    return base
```
`_hash_storage_indices` builds a tuple used as a dict key; if any tag value is unhashable (e.g., `list`/`dict`), this will raise `TypeError` at runtime. Since tags come from `**kwargs`, consider validating/coercing tag values to hashable types (e.g., `str(value)`/`json.dumps`) or restricting the accepted tag value types.
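One coercion strategy along those lines, sketched as a hypothetical helper (`hashable_tag_items` is not in the PR): JSON-encode any unhashable value before it enters the key tuple, so the result is always usable as a dict key.

```python
import json

def hashable_tag_items(tags: dict) -> tuple:
    """Return sorted (key, value) pairs with unhashable values JSON-encoded."""
    items = []
    for key, value in sorted(tags.items()):
        try:
            hash(value)  # scalars (str/int/float/bool/None) pass through
        except TypeError:
            # Lists/dicts are serialized deterministically so equal tags
            # always map to the same partition key.
            value = json.dumps(value, sort_keys=True)
        items.append((key, value))
    return tuple(items)
```

A stricter alternative is to reject non-scalar values outright, which keeps `get` filters unambiguous.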
```python
# Build the dynamic query.
query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
params = [indices.task_name, indices.group_status]

# SQLite's JSON query syntax requires SQLite 3.38+; older versions must emulate it
# with LIKE or skip DB-level filtering.
# The simple approach shown here: filtering tags in Python is inefficient, while
# filtering JSON on the SQL side has more complex syntax. For generality, select
# candidates by task and status only and filter tags in Python (if tags get complex,
# promote them to dedicated columns), or use JSON_EXTRACT (recommended).
for key, value in indices.tags.items():
    # Note: JSON distinguishes numbers from strings; values are assumed to be simple types.
    # $.key extracts the corresponding value.
    query += f" AND json_extract(tags, '$.{key}') = ?"
    params.append(value)

query += f" LIMIT {count}"
```
`SQLStorageBackend.get` builds SQL using f-strings for both the JSON path (`$.{key}`) and `LIMIT {count}`. If `key` is user-controlled (it comes from `indices.tags`), this is a SQL injection risk once this backend is implemented. Prefer validating `key` against an allowlist/regex and using parameter binding for `LIMIT` (and avoid interpolating raw values into the query string).
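A possible hardening, sketched under the assumption that tag keys should be plain identifiers (`build_query` is a hypothetical free function mirroring the snippet's logic; the PR's method would do the same inline):

```python
import re

# Allowlist: tag keys must look like identifiers before entering the JSON path.
_TAG_KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

def build_query(task_name, group_status, tags: dict, count: int):
    query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
    params = [task_name, group_status]
    for key, value in tags.items():
        if not _TAG_KEY_RE.match(key):
            raise ValueError(f"invalid tag key: {key!r}")
        # The JSON path still interpolates the key, but only after the allowlist check;
        # the value itself is always bound as a parameter.
        query += f" AND json_extract(tags, '$.{key}') = ?"
        params.append(value)
    # LIMIT is bound as a parameter instead of f-string interpolation.
    query += " LIMIT ?"
    params.append(count)
    return query, params
```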
```python
tags: dict = field(default_factory=dict)  # for non-equality conditions, use e.g. scores_gt > 0.8
```
The tags-based partitioning logic is part of the public `ReplayBuffer.put`/`get` API (via `**kwargs`), but there are no tests asserting that different tag values map to different storage partitions and don’t mix. Consider adding a small test that writes items with different tag combinations and verifies isolation.
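A sketch of what such a test could look like. Since the PR's exact `put`/`get` signatures are not shown in this excerpt, a toy partitioned store stands in for `ReplayBuffer`; the assertion pattern is the point.

```python
import asyncio
from collections import defaultdict

class ToyPartitionedStore:
    """Hypothetical stand-in partitioning by (task_name, sorted tags)."""

    def __init__(self):
        self._parts = defaultdict(list)

    async def put(self, items, task_name, **tags):
        key = (task_name,) + tuple(sorted(tags.items()))
        self._parts[key].extend(items)

    async def get(self, count, task_name, **tags):
        key = (task_name,) + tuple(sorted(tags.items()))
        got, self._parts[key] = self._parts[key][:count], self._parts[key][count:]
        return got

async def test_tag_isolation():
    buf = ToyPartitionedStore()
    await buf.put(["a1"], task_name="math", split="train")
    await buf.put(["b1"], task_name="math", split="eval")
    # Items written under different tag values must not mix.
    assert await buf.get(10, task_name="math", split="train") == ["a1"]
    assert await buf.get(10, task_name="math", split="eval") == ["b1"]

asyncio.run(test_tag_isolation())
```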
```python
class PandasStorageBackend(StorageBackend):
    def __init__(self, limit: int = 0):
        raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.")
        import pandas as pd
```

This `import` statement is unreachable (it follows the unconditional `raise`).
```python
class SQLStorageBackend(StorageBackend):
    def __init__(self, db_path: str = ":memory:"):
        raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.")
        self.db_path = db_path
```

This assignment is unreachable (it follows the unconditional `raise`).
ReplayBuffer design notes

StorageIndices
A data-index class that gives the different backends a unified indexing method.

StorageBackend
An abstract storage backend that supports different kinds of storage, e.g. the simplest FIFO `FIFOStorageBackend`, the priority-queue `StalenessStorageBackend`, databases, etc.; pseudocode for `PandasStorageBackend` and `SQLStorageBackend` is provided as a reference.

ReplayBuffer